package io.sqrtqiezi.spark.lagou

import org.apache.spark.sql.SparkSession

import scala.math.{pow, sqrt}
import scala.util.control.Breaks._

// 声明 Iris 结构，方便操作
case class Iris(Id: Int,
                SepalLengthCm: Double,
                SepalWidthCm: Double,
                PetalLengthCm: Double,
                PetalWidthCm: Double,
                Species: String)

object KMeansIris {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession
      .builder()
      .master("local")
      .appName("kmeans iris data")
      .getOrCreate()

    import spark.implicits._

    val DISTANCE_THRESHOLD = 0.001 // 迭代计算时的中心点变化阈值

    // 读取数据文件
    var irisRDD = spark.read
      .option("inferSchema", "true")
      .option("header", "true")
      .csv("lagou-data/iris.csv")
      .as[Iris]
      .rdd
      .map(iris => (0, iris))

    // 随机初始化中心点
    var points = Array(
      (1, 5.1,3.5,1.4,0.2),
      (2, 4.9,3.0,1.4,0.2),
      (3, 4.7,3.2,1.3,0.2)
    )

    var steps = 0 // 记录迭代步骤
    breakable {
      while (true) {
        // 根据中心店距离，对数据进行分类
        irisRDD = irisRDD.map { case (_, iris) => (classify(points)(iris), iris) }

        // 保留上一个中心点的值
        val prePoints = points

        points = irisRDD.groupByKey
          .mapValues(items => {
            // 根据分类，计算新的中心点
            // 1. 记录该分类点的个数
            val count = items.size
            // 2. 合计各个字段的值的总和
            val tuple: (Double, Double, Double, Double) = items
              .map(iris => (iris.SepalLengthCm, iris.SepalWidthCm, iris.PetalLengthCm, iris.PetalWidthCm))
              .reduce((acc, item) => (acc._1 + item._1, acc._2 + item._2, acc._3 + item._3, acc._4 + item._4))
            // 3. 总和除以个数，即为新的中心点位置
            (tuple._1 / count, tuple._2 / count, tuple._3 / count, tuple._4 / count)
          })
          .map { case (key, item) => (key, item._1, item._2, item._3, item._4) }
          .sortBy(_._1)
          .collect

        // 计算上一个中心点和当前中心点的距离之和
        var sumDistance:Double = 0
        for (i <- 0 until 3) {
          sumDistance += sqrt(
            pow(prePoints(i)._1 - points(i)._1, 2) +
              pow(prePoints(i)._2 - points(i)._2, 2) +
              pow(prePoints(i)._3 - points(i)._3, 2) +
              pow(prePoints(i)._4 - points(i)._4, 2)
          )
        }
        println(s"sumDistance : $sumDistance")
        steps += 1

        // 当中心点距离小于阈值的时候，说明中心点不再大范围变化，结束迭代计算
        if (sumDistance < DISTANCE_THRESHOLD) {
          println(s"steps : $steps")
          break
        }
      }
    }

    irisRDD.collect.foreach(println)
  }

  // 分类辅助函数，和 iris 距离最近的中心点，为 iris 的分类
  def classify(points: Array[(Int, Double, Double, Double, Double)])(iris: Iris):Int = {
    points.sortBy({ case (_, sepalLengthCm, sepalWidthCm, petalLengthCm, petalWidthCm) =>
      sqrt(pow(sepalLengthCm - iris.SepalLengthCm, 2) +
        pow(sepalWidthCm - iris.SepalWidthCm, 2) +
        pow(petalLengthCm - iris.PetalLengthCm, 2) +
        pow(petalWidthCm - iris.PetalWidthCm, 2)
      )
    }).toList.head._1
  }
}
